{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基于机器学习的pyborker量化交易策略"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<diskcache.core.Cache at 0x27a18d81190>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pybroker\n",
    "from pybroker.ext.data import AKShare\n",
    "from pybroker import ExecContext, StrategyConfig, Strategy\n",
    "from pybroker.data import DataSource\n",
    "import matplotlib.pyplot as plt\n",
    "from datetime import datetime\n",
    "import riskfolio as rp\n",
    "import akshare as ak\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sqlite3\n",
    "import datetime\n",
    "import seaborn as sns\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "import talib\n",
    "from pybroker.vect import cross\n",
    "\n",
    "#正常显示画图时出现的中文和负号\n",
    "from pylab import mpl\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split \n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from sklearn.metrics import r2_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.ensemble import GradientBoostingRegressor\n",
    "import lightgbm as lgb\n",
    "\n",
    "mpl.rcParams['font.sans-serif']=['SimHei']\n",
    "mpl.rcParams['axes.unicode_minus']=False\n",
    "\n",
    "akshare = AKShare()\n",
    "\n",
    "pybroker.enable_data_source_cache('akshare')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 股票预测的因子构建\n",
    "\n",
    "pip install lightgbm -i https://mirrors.aliyun.com/pypi/simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "conn=sqlite3.connect(r'I:\\量化金融\\stock_2018.db')\n",
    "stock_daily1=pd.read_sql(\"select * from stock_daily where 股票代码<'003000.SZ'\",con=conn)\n",
    "stock_daily1[\"交易日期\"]=pd.to_datetime(stock_daily1[\"交易日期\"].astype(str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['index', '交易日期', '股票代码', '股票简称', '开盘价', '最高价', '最低价', '收盘价', '成交量(手)',\n",
       "       '成交额(千元)', '换手率(%)', '量比', '市盈率(静态)', '市盈率(TTM)', '市盈率(动态)', '市净率',\n",
       "       '市销率', '市销率(TTM)', '股息率(%)', '股息率(TTM)(%)', '总股本(万股)', '流通股本(万股)',\n",
       "       '总市值(万元)', '流通市值(万元)'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stock_daily1.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "stock_daily1.columns=['index', \"date\",\"symbol\", '股票简称', \"open\",\"high\",\"low\",\"close\",\"volume\",\n",
    "       '成交额(千元)', '换手率(%)', '量比', '市盈率(静态)', '市盈率(TTM)', '市盈率(动态)', '市净率',\n",
    "       '市销率', '市销率(TTM)', '股息率(%)', '股息率(TTM)(%)', '总股本(万股)', '流通股本(万股)',\n",
    "       '总市值(万元)', '流通市值(万元)']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "stock_daily1_d=stock_daily1.drop([\"index\",\"股票简称\"],axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "#stock_daily1_d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def func0(x):\n",
    "    return x.pct_change().shift(-1)\n",
    "\n",
    "stock_daily1_d[\"return_s1\"]=stock_daily1_d.groupby(\"symbol\", group_keys=False).close.apply(func0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_bb(stock_data):\n",
    "\n",
    "    SMA20_close=talib.SMA(stock_data, timeperiod=20)\n",
    "\n",
    "    high, mid, low = talib.BBANDS(stock_data, timeperiod=20)\n",
    "\n",
    "    x,y,z=talib.MACD(stock_data,fastperiod=12, slowperiod=26, signalperiod=9)\n",
    "\n",
    "    dongliang=talib.MOM(stock_data, timeperiod=10)\n",
    "\n",
    "    rsi_6=talib.RSI(stock_data, timeperiod=6)\n",
    "    rsi_24=talib.RSI(stock_data, timeperiod=24)\n",
    "\n",
    "    up = talib.MAX(stock_data, 20)\n",
    "    down = talib.MIN(stock_data, 20)\n",
    "\n",
    "    return_30=stock_data.pct_change(30)\n",
    "    return_5=stock_data.pct_change(5)\n",
    "    return_10=stock_data.pct_change(10)\n",
    "\n",
    "\n",
    "    df=pd.concat([SMA20_close,high,low,x,y,z,dongliang,rsi_6,rsi_24,up,down,return_30,return_5,return_10],axis=1)\n",
    "    df.columns=[\"SMA20_close\",\"b_high\",\"b_low\",\"MACD_x\",\"MACD_y\",\"MACD_z\",\"dong10\",\"rsi_6\",\"rsi_24\",\"up\",\"down\",\"return_30\",\"return_5\",\"return_10\"]\n",
    "\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "z=stock_daily1_d.groupby(\"symbol\", group_keys=False).close.apply(compute_bb)\n",
    "stock_daily1_d=stock_daily1_d.join(z)\n",
    "stock_daily1_d[\"close-o\"]=stock_daily1_d[\"close\"]-stock_daily1_d[\"open\"]\n",
    "stock_daily1_d[\"high-l\"]=stock_daily1_d[\"high\"]-stock_daily1_d[\"low\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#stock_daily1_d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 基于lightgbm的股票预测\n",
    "\n",
    "* 格点调参"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Timestamp('2023-02-17 00:00:00')"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stock_daily1_d.date.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Timestamp('2018-01-02 00:00:00')"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stock_daily1_d.date.min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "xy=stock_daily1_d[stock_daily1_d.date<datetime.datetime(2021,1,1)].iloc[:,2:].dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "xy_x=xy.drop(\"return_s1\",axis=1)\n",
    "xy_y=xy[\"return_s1\"]\n",
    "x1,x2,y1,y2=train_test_split(xy_x,xy_y,test_size=0.7)#分割数据出训练集与测试集，0.7是两者行数的比例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0016471867933682827\n"
     ]
    }
   ],
   "source": [
    "clf = LinearRegression()\n",
    "#clf = GradientBoostingRegressor()\n",
    "clf = clf.fit(x1,y1)\n",
    "print(r2_score(y2,clf.predict(x2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.029291 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 171108, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000342\n",
      "0.014361957641128353\n"
     ]
    }
   ],
   "source": [
    "# 创建一个LGBMRegressor模型，目标函数为回归，叶子节点数为31，学习率为0.1，估计器个数为30\n",
    "gbm = lgb.LGBMRegressor(objective='regression',num_leaves=31,learning_rate=0.1,n_estimators=30)\n",
    "gbm.fit(x1,y1)\n",
    "y_pred = gbm.predict(x2)\n",
    "print(r2_score(y2,y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**交叉验证的工作原理**\n",
    "在每次迭代中：\n",
    "1. **数据分割**：将训练数据（`x1`, `y1`）随机分成3份，其中2份作为训练集，1份作为验证集。\n",
    "2. **模型训练与评估**：使用当前参数组合在训练集上训练模型，并在验证集上评估性能。\n",
    "3. **重复迭代**：重复上述过程3次，每次使用不同的验证集。\n",
    "4. **结果汇总**：最终返回3次评估结果的平均值作为该参数组合的性能指标。\n",
    "\n",
    "\n",
    "**为什么使用交叉验证？**\n",
    "1. **更可靠的评估**：单一验证集可能受数据划分的随机性影响，交叉验证通过多次平均减少这种偏差。\n",
    "2. **充分利用数据**：所有训练数据都有机会作为验证集，提高模型评估的稳定性。\n",
    "3. **防止过拟合**：帮助选择在不同数据子集上都表现稳定的参数，避免模型对特定数据过拟合。\n",
    "\n",
    "\n",
    "**`cv=3` 的具体示例**\n",
    "假设你的训练数据有900个样本，`cv=3` 会将数据分为3份（每份300个样本）：\n",
    "\n",
    "- **第一次迭代**：使用第1+2份（600个样本）训练，第3份（300个样本）验证。\n",
    "- **第二次迭代**：使用第1+3份训练，第2份验证。\n",
    "- **第三次迭代**：使用第2+3份训练，第1份验证。\n",
    "\n",
    "最终，每个参数组合的性能是这3次验证分数的平均值。\n",
    "\n",
    "**注意事项**\n",
    "- **计算成本**：折数越多，计算时间越长（例如 `cv=10` 比 `cv=3` 慢约3倍）。\n",
    "- **数据量**：数据量较小时，建议使用较高的折数（如 `cv=5` 或 `cv=10`）；数据量较大时，`cv=3` 通常足够。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.027274 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.020374 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019933 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.020125 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.020047 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.018690 seconds.\n",
      "You can set `force_row_wise=true` to remove the overhead.\n",
      "And if memory is not enough, you can set `force_col_wise=true`.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.029364 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019787 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.020622 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.017508 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.017923 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.015791 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.017300 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.029600 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019295 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.018509 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.016861 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.016278 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.018128 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019819 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.026725 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.018985 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.023659 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019433 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.021081 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019600 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019522 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.023118 seconds.\n",
      "You can set `force_row_wise=true` to remove the overhead.\n",
      "And if memory is not enough, you can set `force_col_wise=true`.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019910 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.016351 seconds.\n",
      "You can set `force_row_wise=true` to remove the overhead.\n",
      "And if memory is not enough, you can set `force_col_wise=true`.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019306 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.018600 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.017492 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.017684 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.020604 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.018039 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019317 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.018795 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019897 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.021199 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.016270 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019047 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.021053 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.018501 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019187 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.017157 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019624 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.018525 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.020049 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.020099 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.020218 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.019751 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000273\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.020296 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000400\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.029645 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114072, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000352\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.026424 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 171108, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.000342\n",
      "{'learning_rate': 0.1, 'n_estimators': 50, 'num_leaves': 31}\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "param_grid = {\n",
    "    'num_leaves': [31, 50],\n",
    "    'learning_rate': [0.05, 0.1, 0.15],\n",
    "    'n_estimators': [20, 30, 50]\n",
    "}\n",
    "gbm = lgb.LGBMRegressor(objective='regression')\n",
    "grid_search = GridSearchCV(gbm, param_grid, cv=3)\n",
    "grid_search.fit(x1,y1)\n",
    "print(grid_search.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.015561255775650817\n"
     ]
    }
   ],
   "source": [
    "best_model = grid_search.best_estimator_\n",
    "y_pred = best_model.predict(x2)\n",
    "print(r2_score(y2,y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.003189569666169456\n"
     ]
    }
   ],
   "source": [
    "xy_t=stock_daily1_d[(stock_daily1_d.date>datetime.datetime(2021,1,1))&(stock_daily1_d.date<datetime.datetime(2021,5,1))].iloc[:,2:].dropna()\n",
    "xy_x=xy_t.drop(\"return_s1\",axis=1)\n",
    "xy_y=xy_t[\"return_s1\"]\n",
    "y_pred = best_model.predict(xy_x)\n",
    "print(r2_score(xy_y,y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.0016841 ,  0.00032934,  0.00093547, ..., -0.00053129,\n",
       "        0.00057716,  0.00047847])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 在pybroker上使用单日收益率预测模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred=pd.Series(y_pred,index=xy_x.index,name=\"pred\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "730        0.001684\n",
       "731        0.000329\n",
       "732        0.000935\n",
       "733        0.000918\n",
       "734        0.000882\n",
       "             ...   \n",
       "1726342   -0.001794\n",
       "1726343   -0.002777\n",
       "1726344   -0.000531\n",
       "1726345    0.000577\n",
       "1726346    0.000478\n",
       "Name: pred, Length: 64898, dtype: float64"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "#stock_daily1_d[(stock_daily1_d.date>datetime.datetime(2021,1,1))&(stock_daily1_d.date<datetime.datetime(2021,5,1))].iloc[:,0:7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyb_data_pe=stock_daily1_d[(stock_daily1_d.date>datetime.datetime(2021,1,1))&(stock_daily1_d.date<datetime.datetime(2021,5,1))].iloc[:,0:7]\n",
    "pyb_data_pe = pd.concat([pyb_data_pe, y_pred], axis=1)\n",
    "pyb_data_pe.fillna(0,inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>date</th>\n",
       "      <th>symbol</th>\n",
       "      <th>open</th>\n",
       "      <th>high</th>\n",
       "      <th>low</th>\n",
       "      <th>close</th>\n",
       "      <th>volume</th>\n",
       "      <th>pred</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>730</th>\n",
       "      <td>2021-01-04</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2121.0359</td>\n",
       "      <td>2121.0359</td>\n",
       "      <td>2047.7436</td>\n",
       "      <td>2065.5114</td>\n",
       "      <td>1554216.43</td>\n",
       "      <td>0.001684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>731</th>\n",
       "      <td>2021-01-05</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2043.3016</td>\n",
       "      <td>2052.1855</td>\n",
       "      <td>1976.6722</td>\n",
       "      <td>2017.7603</td>\n",
       "      <td>1821352.10</td>\n",
       "      <td>0.000329</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>732</th>\n",
       "      <td>2021-01-06</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2007.7659</td>\n",
       "      <td>2172.1184</td>\n",
       "      <td>1998.8820</td>\n",
       "      <td>2172.1184</td>\n",
       "      <td>1934945.12</td>\n",
       "      <td>0.000935</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>733</th>\n",
       "      <td>2021-01-07</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2167.6765</td>\n",
       "      <td>2218.7590</td>\n",
       "      <td>2135.4723</td>\n",
       "      <td>2209.8751</td>\n",
       "      <td>1584185.30</td>\n",
       "      <td>0.000918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>734</th>\n",
       "      <td>2021-01-08</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2209.8751</td>\n",
       "      <td>2232.0849</td>\n",
       "      <td>2144.3562</td>\n",
       "      <td>2204.3226</td>\n",
       "      <td>1195473.22</td>\n",
       "      <td>0.000882</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730540</th>\n",
       "      <td>2021-04-26</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>12.3500</td>\n",
       "      <td>12.6800</td>\n",
       "      <td>12.2300</td>\n",
       "      <td>12.3200</td>\n",
       "      <td>54335.98</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730541</th>\n",
       "      <td>2021-04-27</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>12.5800</td>\n",
       "      <td>12.6300</td>\n",
       "      <td>11.8800</td>\n",
       "      <td>11.9400</td>\n",
       "      <td>57535.60</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730542</th>\n",
       "      <td>2021-04-28</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>11.9900</td>\n",
       "      <td>12.0100</td>\n",
       "      <td>11.7200</td>\n",
       "      <td>11.8400</td>\n",
       "      <td>40535.97</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730543</th>\n",
       "      <td>2021-04-29</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>12.0000</td>\n",
       "      <td>12.1900</td>\n",
       "      <td>11.7700</td>\n",
       "      <td>11.7900</td>\n",
       "      <td>49342.27</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730544</th>\n",
       "      <td>2021-04-30</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>11.7900</td>\n",
       "      <td>11.8400</td>\n",
       "      <td>11.5100</td>\n",
       "      <td>11.5400</td>\n",
       "      <td>37447.64</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>111632 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "              date     symbol       open       high        low      close  \\\n",
       "730     2021-01-04  000001.SZ  2121.0359  2121.0359  2047.7436  2065.5114   \n",
       "731     2021-01-05  000001.SZ  2043.3016  2052.1855  1976.6722  2017.7603   \n",
       "732     2021-01-06  000001.SZ  2007.7659  2172.1184  1998.8820  2172.1184   \n",
       "733     2021-01-07  000001.SZ  2167.6765  2218.7590  2135.4723  2209.8751   \n",
       "734     2021-01-08  000001.SZ  2209.8751  2232.0849  2144.3562  2204.3226   \n",
       "...            ...        ...        ...        ...        ...        ...   \n",
       "1730540 2021-04-26  002999.SZ    12.3500    12.6800    12.2300    12.3200   \n",
       "1730541 2021-04-27  002999.SZ    12.5800    12.6300    11.8800    11.9400   \n",
       "1730542 2021-04-28  002999.SZ    11.9900    12.0100    11.7200    11.8400   \n",
       "1730543 2021-04-29  002999.SZ    12.0000    12.1900    11.7700    11.7900   \n",
       "1730544 2021-04-30  002999.SZ    11.7900    11.8400    11.5100    11.5400   \n",
       "\n",
       "             volume      pred  \n",
       "730      1554216.43  0.001684  \n",
       "731      1821352.10  0.000329  \n",
       "732      1934945.12  0.000935  \n",
       "733      1584185.30  0.000918  \n",
       "734      1195473.22  0.000882  \n",
       "...             ...       ...  \n",
       "1730540    54335.98  0.000000  \n",
       "1730541    57535.60  0.000000  \n",
       "1730542    40535.97  0.000000  \n",
       "1730543    49342.27  0.000000  \n",
       "1730544    37447.64  0.000000  \n",
       "\n",
       "[111632 rows x 8 columns]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pyb_data_pe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "#pyb_data_pe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "count    111632.000000\n",
       "mean          0.000292\n",
       "std           0.002206\n",
       "min          -0.062534\n",
       "25%           0.000000\n",
       "50%           0.000000\n",
       "75%           0.000606\n",
       "max           0.089635\n",
       "Name: pred, dtype: float64"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pyb_data_pe.pred.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['000001.SZ', '000002.SZ', '000004.SZ', '000005.SZ', '000006.SZ',\n",
       "       '000007.SZ', '000008.SZ', '000009.SZ', '000010.SZ', '000011.SZ'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pyb_data_pe.symbol.unique()[0:10]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hold_long(ctx):\n",
    "    if not ctx.long_pos():\n",
    "        # Buy if the next bar is predicted to have a positive return:\n",
    "        if ctx.pred[-1] > 0.:\n",
    "            ctx.buy_shares = ctx.calc_target_shares(0.01)\n",
    "    else:\n",
    "        # Sell if the next bar is predicted to have a negative return:\n",
    "        if ctx.pred[-1] < 0.:\n",
    "            ctx.sell_shares = ctx.calc_target_shares(0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backtesting: 2021-01-01 00:00:00 to 2021-05-01 00:00:00\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test split: 2021-01-04 00:00:00 to 2021-04-30 00:00:00\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0% (0 of 79) |                         | Elapsed Time: 0:00:00 ETA:  --:--:--\n",
      "  1% (1 of 79) |                         | Elapsed Time: 0:00:00 ETA:   0:00:18\n",
      " 13% (11 of 79) |###                     | Elapsed Time: 0:00:00 ETA:   0:00:03\n",
      " 26% (21 of 79) |######                  | Elapsed Time: 0:00:00 ETA:   0:00:01\n",
      " 39% (31 of 79) |#########               | Elapsed Time: 0:00:00 ETA:   0:00:01\n",
      " 51% (41 of 79) |############            | Elapsed Time: 0:00:01 ETA:   0:00:00\n",
      " 64% (51 of 79) |###############         | Elapsed Time: 0:00:01 ETA:   0:00:00\n",
      " 77% (61 of 79) |##################      | Elapsed Time: 0:00:01 ETA:   0:00:00\n",
      " 89% (71 of 79) |#####################   | Elapsed Time: 0:00:01 ETA:   0:00:00\n",
      "100% (79 of 79) |########################| Elapsed Time: 0:00:01 ETA:  00:00:00\n",
      "100% (79 of 79) |########################| Elapsed Time: 0:00:01 Time:  0:00:01\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished backtest: 0:00:02\n"
     ]
    }
   ],
   "source": [
    "pybroker.register_columns('pred')\n",
    "config = StrategyConfig(initial_cash=10000000)\n",
    "strategy = Strategy(pyb_data_pe, '2021-01-01', '2021-05-01',config)\n",
    "strategy.add_execution(hold_long, pyb_data_pe.symbol.unique()[250:350])\n",
    "result = strategy.backtest()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>trade_count</td>\n",
       "      <td>847.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>initial_market_value</td>\n",
       "      <td>10000000.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>end_market_value</td>\n",
       "      <td>10054134.95</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>total_pnl</td>\n",
       "      <td>43682.91</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>unrealized_pnl</td>\n",
       "      <td>10452.04</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   name        value\n",
       "0           trade_count       847.00\n",
       "1  initial_market_value  10000000.00\n",
       "2      end_market_value  10054134.95\n",
       "3             total_pnl     43682.91\n",
       "4        unrealized_pnl     10452.04"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.metrics_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 轮动策略"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.2"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stock_n=5\n",
    "config = StrategyConfig(max_long_positions=stock_n)\n",
    "pybroker.param('target_size', 1 / stock_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rand_ld(ctxs: dict[str, ExecContext]):\n",
    "    symbols=[]\n",
    "    ld=[]\n",
    "    for ctx in ctxs.values():\n",
    "        symbols.append(ctx.symbol)\n",
    "        ld.append(ctx.pred[-1])\n",
    "    \n",
    "    v_rank=pd.DataFrame({\"symbol\":symbols,\"ld\":ld}).sort_values(by=\"ld\",ascending=False)\n",
    "    pybroker.param('symbols', v_rank.head(stock_n).symbol.values)\n",
    "\n",
    "def rotate(ctx: ExecContext):\n",
    "    if ctx.long_pos():\n",
    "        if ctx.symbol not in pybroker.param('symbols'):\n",
    "            ctx.sell_all_shares()\n",
    "    else:\n",
    "        ctx.buy_limit_price = ctx.close[-1] * 1.095\n",
    "        target_size = pybroker.param('target_size')\n",
    "        ctx.buy_shares = ctx.calc_target_shares(target_size)\n",
    "        ctx.score = ctx.pred[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>date</th>\n",
       "      <th>symbol</th>\n",
       "      <th>open</th>\n",
       "      <th>high</th>\n",
       "      <th>low</th>\n",
       "      <th>close</th>\n",
       "      <th>volume</th>\n",
       "      <th>pred</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>730</th>\n",
       "      <td>2021-01-04</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2121.0359</td>\n",
       "      <td>2121.0359</td>\n",
       "      <td>2047.7436</td>\n",
       "      <td>2065.5114</td>\n",
       "      <td>1554216.43</td>\n",
       "      <td>0.000274</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>731</th>\n",
       "      <td>2021-01-05</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2043.3016</td>\n",
       "      <td>2052.1855</td>\n",
       "      <td>1976.6722</td>\n",
       "      <td>2017.7603</td>\n",
       "      <td>1821352.10</td>\n",
       "      <td>-0.000640</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>732</th>\n",
       "      <td>2021-01-06</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2007.7659</td>\n",
       "      <td>2172.1184</td>\n",
       "      <td>1998.8820</td>\n",
       "      <td>2172.1184</td>\n",
       "      <td>1934945.12</td>\n",
       "      <td>0.000760</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>733</th>\n",
       "      <td>2021-01-07</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2167.6765</td>\n",
       "      <td>2218.7590</td>\n",
       "      <td>2135.4723</td>\n",
       "      <td>2209.8751</td>\n",
       "      <td>1584185.30</td>\n",
       "      <td>0.000760</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>734</th>\n",
       "      <td>2021-01-08</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2209.8751</td>\n",
       "      <td>2232.0849</td>\n",
       "      <td>2144.3562</td>\n",
       "      <td>2204.3226</td>\n",
       "      <td>1195473.22</td>\n",
       "      <td>0.001234</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730540</th>\n",
       "      <td>2021-04-26</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>12.3500</td>\n",
       "      <td>12.6800</td>\n",
       "      <td>12.2300</td>\n",
       "      <td>12.3200</td>\n",
       "      <td>54335.98</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730541</th>\n",
       "      <td>2021-04-27</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>12.5800</td>\n",
       "      <td>12.6300</td>\n",
       "      <td>11.8800</td>\n",
       "      <td>11.9400</td>\n",
       "      <td>57535.60</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730542</th>\n",
       "      <td>2021-04-28</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>11.9900</td>\n",
       "      <td>12.0100</td>\n",
       "      <td>11.7200</td>\n",
       "      <td>11.8400</td>\n",
       "      <td>40535.97</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730543</th>\n",
       "      <td>2021-04-29</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>12.0000</td>\n",
       "      <td>12.1900</td>\n",
       "      <td>11.7700</td>\n",
       "      <td>11.7900</td>\n",
       "      <td>49342.27</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1730544</th>\n",
       "      <td>2021-04-30</td>\n",
       "      <td>002999.SZ</td>\n",
       "      <td>11.7900</td>\n",
       "      <td>11.8400</td>\n",
       "      <td>11.5100</td>\n",
       "      <td>11.5400</td>\n",
       "      <td>37447.64</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>111632 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "              date     symbol       open       high        low      close  \\\n",
       "730     2021-01-04  000001.SZ  2121.0359  2121.0359  2047.7436  2065.5114   \n",
       "731     2021-01-05  000001.SZ  2043.3016  2052.1855  1976.6722  2017.7603   \n",
       "732     2021-01-06  000001.SZ  2007.7659  2172.1184  1998.8820  2172.1184   \n",
       "733     2021-01-07  000001.SZ  2167.6765  2218.7590  2135.4723  2209.8751   \n",
       "734     2021-01-08  000001.SZ  2209.8751  2232.0849  2144.3562  2204.3226   \n",
       "...            ...        ...        ...        ...        ...        ...   \n",
       "1730540 2021-04-26  002999.SZ    12.3500    12.6800    12.2300    12.3200   \n",
       "1730541 2021-04-27  002999.SZ    12.5800    12.6300    11.8800    11.9400   \n",
       "1730542 2021-04-28  002999.SZ    11.9900    12.0100    11.7200    11.8400   \n",
       "1730543 2021-04-29  002999.SZ    12.0000    12.1900    11.7700    11.7900   \n",
       "1730544 2021-04-30  002999.SZ    11.7900    11.8400    11.5100    11.5400   \n",
       "\n",
       "             volume      pred  \n",
       "730      1554216.43  0.000274  \n",
       "731      1821352.10 -0.000640  \n",
       "732      1934945.12  0.000760  \n",
       "733      1584185.30  0.000760  \n",
       "734      1195473.22  0.001234  \n",
       "...             ...       ...  \n",
       "1730540    54335.98  0.000000  \n",
       "1730541    57535.60  0.000000  \n",
       "1730542    40535.97  0.000000  \n",
       "1730543    49342.27  0.000000  \n",
       "1730544    37447.64  0.000000  \n",
       "\n",
       "[111632 rows x 8 columns]"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pyb_data_pe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backtesting: 2021-01-01 00:00:00 to 2021-05-01 00:00:00\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test split: 2021-01-04 00:00:00 to 2021-04-30 00:00:00\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0% (0 of 79) |                         | Elapsed Time: 0:00:00 ETA:  --:--:--\n",
      "  1% (1 of 79) |                         | Elapsed Time: 0:00:03 ETA:   0:04:37\n",
      " 13% (11 of 79) |###                     | Elapsed Time: 0:00:07 ETA:   0:00:48\n",
      " 26% (21 of 79) |######                  | Elapsed Time: 0:00:09 ETA:   0:00:27\n",
      " 39% (31 of 79) |#########               | Elapsed Time: 0:00:11 ETA:   0:00:18\n",
      " 51% (41 of 79) |############            | Elapsed Time: 0:00:13 ETA:   0:00:12\n",
      " 64% (51 of 79) |###############         | Elapsed Time: 0:00:15 ETA:   0:00:08\n",
      " 77% (61 of 79) |##################      | Elapsed Time: 0:00:17 ETA:   0:00:05\n",
      " 89% (71 of 79) |#####################   | Elapsed Time: 0:00:19 ETA:   0:00:02\n",
      "100% (79 of 79) |########################| Elapsed Time: 0:00:21 ETA:  00:00:00\n",
      "100% (79 of 79) |########################| Elapsed Time: 0:00:21 Time:  0:00:21\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished backtest: 0:00:22\n"
     ]
    }
   ],
   "source": [
    "pybroker.register_columns('pred')\n",
    "config = StrategyConfig(initial_cash=10000000)\n",
    "strategy = Strategy(pyb_data_pe, '2021-01-01', '2021-05-01', config)\n",
    "strategy.add_execution(rotate, pyb_data_pe.symbol.unique())\n",
    "strategy.set_before_exec(rand_ld)\n",
    "result = strategy.backtest()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>type</th>\n",
       "      <th>symbol</th>\n",
       "      <th>entry_date</th>\n",
       "      <th>exit_date</th>\n",
       "      <th>entry</th>\n",
       "      <th>exit</th>\n",
       "      <th>shares</th>\n",
       "      <th>pnl</th>\n",
       "      <th>return_pct</th>\n",
       "      <th>agg_pnl</th>\n",
       "      <th>bars</th>\n",
       "      <th>pnl_per_bar</th>\n",
       "      <th>stop</th>\n",
       "      <th>mae</th>\n",
       "      <th>mfe</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>long</td>\n",
       "      <td>000001.SZ</td>\n",
       "      <td>2021-01-05</td>\n",
       "      <td>2021-01-06</td>\n",
       "      <td>2014.43</td>\n",
       "      <td>2085.50</td>\n",
       "      <td>968</td>\n",
       "      <td>68795.76</td>\n",
       "      <td>3.53</td>\n",
       "      <td>68795.76</td>\n",
       "      <td>1</td>\n",
       "      <td>68795.76</td>\n",
       "      <td>None</td>\n",
       "      <td>-37.76</td>\n",
       "      <td>71.07</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>long</td>\n",
       "      <td>000002.SZ</td>\n",
       "      <td>2021-01-05</td>\n",
       "      <td>2021-01-06</td>\n",
       "      <td>4234.59</td>\n",
       "      <td>4346.16</td>\n",
       "      <td>467</td>\n",
       "      <td>52103.19</td>\n",
       "      <td>2.63</td>\n",
       "      <td>120898.95</td>\n",
       "      <td>1</td>\n",
       "      <td>52103.19</td>\n",
       "      <td>None</td>\n",
       "      <td>-63.87</td>\n",
       "      <td>111.57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>long</td>\n",
       "      <td>000004.SZ</td>\n",
       "      <td>2021-01-05</td>\n",
       "      <td>2021-01-06</td>\n",
       "      <td>84.69</td>\n",
       "      <td>83.05</td>\n",
       "      <td>23412</td>\n",
       "      <td>-38395.68</td>\n",
       "      <td>-1.94</td>\n",
       "      <td>82503.27</td>\n",
       "      <td>1</td>\n",
       "      <td>-38395.68</td>\n",
       "      <td>None</td>\n",
       "      <td>-1.64</td>\n",
       "      <td>0.98</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>long</td>\n",
       "      <td>000005.SZ</td>\n",
       "      <td>2021-01-05</td>\n",
       "      <td>2021-01-06</td>\n",
       "      <td>23.08</td>\n",
       "      <td>22.57</td>\n",
       "      <td>85633</td>\n",
       "      <td>-43672.83</td>\n",
       "      <td>-2.21</td>\n",
       "      <td>38830.44</td>\n",
       "      <td>1</td>\n",
       "      <td>-43672.83</td>\n",
       "      <td>None</td>\n",
       "      <td>-0.51</td>\n",
       "      <td>0.28</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>long</td>\n",
       "      <td>000006.SZ</td>\n",
       "      <td>2021-01-05</td>\n",
       "      <td>2021-01-06</td>\n",
       "      <td>199.15</td>\n",
       "      <td>197.32</td>\n",
       "      <td>9924</td>\n",
       "      <td>-18160.92</td>\n",
       "      <td>-0.92</td>\n",
       "      <td>20669.52</td>\n",
       "      <td>1</td>\n",
       "      <td>-18160.92</td>\n",
       "      <td>None</td>\n",
       "      <td>-2.38</td>\n",
       "      <td>2.37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>565</th>\n",
       "      <td>long</td>\n",
       "      <td>000004.SZ</td>\n",
       "      <td>2021-04-29</td>\n",
       "      <td>2021-04-30</td>\n",
       "      <td>64.76</td>\n",
       "      <td>64.96</td>\n",
       "      <td>27478</td>\n",
       "      <td>5495.60</td>\n",
       "      <td>0.31</td>\n",
       "      <td>-1165769.24</td>\n",
       "      <td>1</td>\n",
       "      <td>5495.60</td>\n",
       "      <td>None</td>\n",
       "      <td>-1.69</td>\n",
       "      <td>1.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>566</th>\n",
       "      <td>long</td>\n",
       "      <td>000008.SZ</td>\n",
       "      <td>2021-04-29</td>\n",
       "      <td>2021-04-30</td>\n",
       "      <td>50.53</td>\n",
       "      <td>50.42</td>\n",
       "      <td>33750</td>\n",
       "      <td>-3712.50</td>\n",
       "      <td>-0.22</td>\n",
       "      <td>-1169481.74</td>\n",
       "      <td>1</td>\n",
       "      <td>-3712.50</td>\n",
       "      <td>None</td>\n",
       "      <td>-0.34</td>\n",
       "      <td>0.34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>567</th>\n",
       "      <td>long</td>\n",
       "      <td>000020.SZ</td>\n",
       "      <td>2021-04-29</td>\n",
       "      <td>2021-04-30</td>\n",
       "      <td>16.00</td>\n",
       "      <td>15.92</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.08</td>\n",
       "      <td>-0.50</td>\n",
       "      <td>-1169481.82</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.08</td>\n",
       "      <td>None</td>\n",
       "      <td>-0.14</td>\n",
       "      <td>0.15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>568</th>\n",
       "      <td>long</td>\n",
       "      <td>000040.SZ</td>\n",
       "      <td>2021-04-29</td>\n",
       "      <td>2021-04-30</td>\n",
       "      <td>11.20</td>\n",
       "      <td>11.11</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.09</td>\n",
       "      <td>-0.80</td>\n",
       "      <td>-1169481.91</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.09</td>\n",
       "      <td>None</td>\n",
       "      <td>-0.11</td>\n",
       "      <td>0.11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>569</th>\n",
       "      <td>long</td>\n",
       "      <td>002198.SZ</td>\n",
       "      <td>2021-04-29</td>\n",
       "      <td>2021-04-30</td>\n",
       "      <td>32.73</td>\n",
       "      <td>29.48</td>\n",
       "      <td>52353</td>\n",
       "      <td>-170147.25</td>\n",
       "      <td>-9.93</td>\n",
       "      <td>-1339629.16</td>\n",
       "      <td>1</td>\n",
       "      <td>-170147.25</td>\n",
       "      <td>None</td>\n",
       "      <td>-3.25</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>569 rows × 15 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     type     symbol entry_date  exit_date    entry     exit  shares  \\\n",
       "id                                                                     \n",
       "1    long  000001.SZ 2021-01-05 2021-01-06  2014.43  2085.50     968   \n",
       "2    long  000002.SZ 2021-01-05 2021-01-06  4234.59  4346.16     467   \n",
       "3    long  000004.SZ 2021-01-05 2021-01-06    84.69    83.05   23412   \n",
       "4    long  000005.SZ 2021-01-05 2021-01-06    23.08    22.57   85633   \n",
       "5    long  000006.SZ 2021-01-05 2021-01-06   199.15   197.32    9924   \n",
       "..    ...        ...        ...        ...      ...      ...     ...   \n",
       "565  long  000004.SZ 2021-04-29 2021-04-30    64.76    64.96   27478   \n",
       "566  long  000008.SZ 2021-04-29 2021-04-30    50.53    50.42   33750   \n",
       "567  long  000020.SZ 2021-04-29 2021-04-30    16.00    15.92       1   \n",
       "568  long  000040.SZ 2021-04-29 2021-04-30    11.20    11.11       1   \n",
       "569  long  002198.SZ 2021-04-29 2021-04-30    32.73    29.48   52353   \n",
       "\n",
       "           pnl  return_pct     agg_pnl  bars  pnl_per_bar  stop    mae     mfe  \n",
       "id                                                                              \n",
       "1     68795.76        3.53    68795.76     1     68795.76  None -37.76   71.07  \n",
       "2     52103.19        2.63   120898.95     1     52103.19  None -63.87  111.57  \n",
       "3    -38395.68       -1.94    82503.27     1    -38395.68  None  -1.64    0.98  \n",
       "4    -43672.83       -2.21    38830.44     1    -43672.83  None  -0.51    0.28  \n",
       "5    -18160.92       -0.92    20669.52     1    -18160.92  None  -2.38    2.37  \n",
       "..         ...         ...         ...   ...          ...   ...    ...     ...  \n",
       "565    5495.60        0.31 -1165769.24     1      5495.60  None  -1.69    1.69  \n",
       "566   -3712.50       -0.22 -1169481.74     1     -3712.50  None  -0.34    0.34  \n",
       "567      -0.08       -0.50 -1169481.82     1        -0.08  None  -0.14    0.15  \n",
       "568      -0.09       -0.80 -1169481.91     1        -0.09  None  -0.11    0.11  \n",
       "569 -170147.25       -9.93 -1339629.16     1   -170147.25  None  -3.25    0.00  \n",
       "\n",
       "[569 rows x 15 columns]"
      ]
     },
     "execution_count": 144,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.trades"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>trade_count</td>\n",
       "      <td>581.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>initial_market_value</td>\n",
       "      <td>10000000.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>end_market_value</td>\n",
       "      <td>8596660.57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>total_pnl</td>\n",
       "      <td>-1440405.70</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>unrealized_pnl</td>\n",
       "      <td>37066.27</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   name        value\n",
       "0           trade_count       581.00\n",
       "1  initial_market_value  10000000.00\n",
       "2      end_market_value   8596660.57\n",
       "3             total_pnl  -1440405.70\n",
       "4        unrealized_pnl     37066.27"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.metrics_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 基于机器学习的因子排序策略实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "def func10(x):\n",
    "    return x.pct_change(periods=10).shift(-10)\n",
    "\n",
    "stock_daily1_d[\"return_s10\"]=stock_daily1_d.groupby(\"symbol\", group_keys=False).close.apply(func10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "xy=stock_daily1_d[stock_daily1_d.date<datetime.datetime(2021,1,1)].iloc[:,2:].dropna()\n",
    "xy_x=xy.drop([\"return_s1\",\"return_s10\"],axis=1)\n",
    "xy_y=xy[\"return_s10\"]\n",
    "x1,x2,y1,y2=train_test_split(xy_x,xy_y,test_size=0.8)#分割数据出训练集与测试集，0.7是两者行数的比例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.017790 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 9180\n",
      "[LightGBM] [Info] Number of data points in the train set: 114070, number of used features: 36\n",
      "[LightGBM] [Info] Start training from score 0.003236\n",
      "0.051109060658275474\n"
     ]
    }
   ],
   "source": [
    "gbm = lgb.LGBMRegressor(objective='regression',num_leaves=31,learning_rate=0.2,n_estimators=50)\n",
    "gbm.fit(x1,y1)\n",
    "y_pred = gbm.predict(x2)\n",
    "print(r2_score(y2,y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.016545272675347666\n"
     ]
    }
   ],
   "source": [
    "xy_t=stock_daily1_d[(stock_daily1_d.date>datetime.datetime(2021,1,1))&(stock_daily1_d.date<datetime.datetime(2021,5,1))].iloc[:,2:].dropna()\n",
    "xy_x=xy_t.drop([\"return_s1\",\"return_s10\"],axis=1)\n",
    "xy_y=xy_t[\"return_s10\"]\n",
    "y_pred = gbm.predict(xy_x)\n",
    "print(r2_score(xy_y,y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred=pd.Series(y_pred,index=xy_x.index,name=\"pred\")\n",
    "pyb_data_pe=stock_daily1_d[(stock_daily1_d.date>datetime.datetime(2021,1,1))&(stock_daily1_d.date<datetime.datetime(2021,5,1))].iloc[:,0:7]\n",
    "pyb_data_pe = pd.concat([pyb_data_pe, y_pred], axis=1)\n",
    "pyb_data_pe.fillna(0,inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "#pyb_data_pe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def buy_highest_volume(ctx):\n",
    "    # If there are no long positions across all tickers being traded:\n",
    "    if not tuple(ctx.long_positions()):\n",
    "        ctx.buy_shares = ctx.calc_target_shares(0.2)\n",
    "        ctx.hold_bars = 10\n",
    "        ctx.score = ctx.pred[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backtesting: 2021-01-01 00:00:00 to 2021-05-01 00:00:00\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test split: 2021-01-04 00:00:00 to 2021-04-30 00:00:00\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0% (0 of 79) |                         | Elapsed Time: 0:00:00 ETA:  --:--:--\n",
      "  1% (1 of 79) |                         | Elapsed Time: 0:00:03 ETA:   0:03:55\n",
      " 13% (11 of 79) |###                     | Elapsed Time: 0:00:05 ETA:   0:00:33\n",
      " 26% (21 of 79) |######                  | Elapsed Time: 0:00:05 ETA:   0:00:15\n",
      " 39% (31 of 79) |#########               | Elapsed Time: 0:00:06 ETA:   0:00:09\n",
      " 51% (41 of 79) |############            | Elapsed Time: 0:00:06 ETA:   0:00:05\n",
      " 64% (51 of 79) |###############         | Elapsed Time: 0:00:06 ETA:   0:00:03\n",
      " 77% (61 of 79) |##################      | Elapsed Time: 0:00:06 ETA:   0:00:02\n",
      " 89% (71 of 79) |#####################   | Elapsed Time: 0:00:07 ETA:   0:00:00\n",
      "100% (79 of 79) |########################| Elapsed Time: 0:00:07 ETA:  00:00:00\n",
      "100% (79 of 79) |########################| Elapsed Time: 0:00:07 Time:  0:00:07\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished backtest: 0:00:08\n"
     ]
    }
   ],
   "source": [
    "pybroker.register_columns('pred')\n",
    "strategy = Strategy(pyb_data_pe, '2021-01-01', '2021-05-01')\n",
    "strategy.add_execution(buy_highest_volume, pyb_data_pe.symbol.unique())\n",
    "result = strategy.backtest()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>trade_count</td>\n",
       "      <td>49.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>initial_market_value</td>\n",
       "      <td>100000.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>end_market_value</td>\n",
       "      <td>93075.55</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>total_pnl</td>\n",
       "      <td>-7356.97</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>unrealized_pnl</td>\n",
       "      <td>432.52</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   name      value\n",
       "0           trade_count      49.00\n",
       "1  initial_market_value  100000.00\n",
       "2      end_market_value   93075.55\n",
       "3             total_pnl   -7356.97\n",
       "4        unrealized_pnl     432.52"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.metrics_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "#result.trades"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "730       -0.023118\n",
       "731        0.076500\n",
       "732        0.017382\n",
       "733       -0.002513\n",
       "734        0.026700\n",
       "             ...   \n",
       "1726342   -0.070123\n",
       "1726343   -0.033101\n",
       "1726344    0.005406\n",
       "1726345    0.002943\n",
       "1726346   -0.010211\n",
       "Name: return_s1, Length: 64898, dtype: float64"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xy_y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 基于akshare数据源的机器学习算法策略"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b8852f43dbc844b7b5f4f0604bb5977b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "96587ffbbabd4f009c4643b328ae3c54",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    序号      代码     名称     最新价   涨跌幅    涨跌额      成交量           成交额     振幅  \\\n",
      "0    1  600686   金龙汽车   13.07  2.91   0.37   265232  3.461206e+08   6.30   \n",
      "1    2  600375   汉马科技    6.67  1.37   0.09   826696  5.562643e+08  11.55   \n",
      "2    3  000550   江铃汽车   21.17  0.43   0.09    45711  9.747963e+07   3.32   \n",
      "3    4  600166   福田汽车    2.68  0.00   0.00   372367  9.971643e+07   1.87   \n",
      "4    5  200550  江  铃Ｂ    9.93 -0.10  -0.01      283  2.812130e+05   0.91   \n",
      "5    6  000868   安凯客车    5.83 -0.17  -0.01   134752  7.844414e+07   2.57   \n",
      "6    7  601777   千里科技    8.37 -0.24  -0.02   263089  2.187911e+08   3.93   \n",
      "7    8  600303   曙光股份    4.07 -0.25  -0.01   102177  4.212229e+07   2.94   \n",
      "8    9  000957   中通客车   11.34 -0.61  -0.07   166388  1.898896e+08   2.45   \n",
      "9   10  600006   东风股份    7.65 -1.03  -0.08   580774  4.495370e+08   3.23   \n",
      "10  11  600418   江淮汽车   40.36 -1.10  -0.45   570086  2.335257e+09   5.02   \n",
      "11  12  000800   一汽解放    7.19 -1.24  -0.09    88571  6.410694e+07   1.65   \n",
      "12  13  000951   中国重汽   17.24 -1.26  -0.22    91651  1.596724e+08   3.72   \n",
      "13  14  200625  长  安Ｂ    3.79 -1.30  -0.05    17156  6.537141e+06   1.30   \n",
      "14  15  301039   中集车辆    8.26 -1.31  -0.11    60181  5.010247e+07   1.55   \n",
      "15  16  600104   上汽集团   16.54 -1.96  -0.33   208803  3.472286e+08   1.96   \n",
      "16  17  601238   广汽集团    7.88 -2.11  -0.17   152408  1.210570e+08   2.48   \n",
      "17  18  600066   宇通客车   24.53 -2.27  -0.57    90542  2.238970e+08   2.63   \n",
      "18  19  000625   长安汽车   12.50 -2.42  -0.31   784466  9.853338e+08   1.95   \n",
      "19  20  601633   长城汽车   22.82 -2.56  -0.60   145935  3.358966e+08   2.69   \n",
      "20  21  000572   海马汽车    3.95 -2.71  -0.11   274798  1.091853e+08   2.22   \n",
      "21  22  600841   动力新科    5.92 -2.79  -0.17   354532  2.095962e+08   3.45   \n",
      "22  23  600733   北汽蓝谷    7.36 -2.90  -0.22   888899  6.581351e+08   2.51   \n",
      "23  24  000980   众泰汽车    2.09 -3.69  -0.08  1056309  2.220670e+08   3.69   \n",
      "24  25  601127    赛力斯  139.60 -4.30  -6.27   345386  4.907965e+09   4.03   \n",
      "25  26  002594    比亚迪  384.19 -5.14 -20.81   207884  8.134126e+09   4.99   \n",
      "\n",
      "        最高      最低      今开      昨收   换手率  市盈率-动态    市净率  \n",
      "0    13.34   12.54   12.60   12.70  3.70   50.10   2.98  \n",
      "1     7.15    6.39    6.43    6.58  7.60  227.32   3.23  \n",
      "2    21.60   20.90   21.01   21.08  0.88   14.93   1.58  \n",
      "3     2.70    2.65    2.66    2.68  0.57   12.18   1.44  \n",
      "4    10.02    9.93    9.96    9.94  0.01    6.46   0.68  \n",
      "5     5.91    5.76    5.80    5.84  1.84  313.33   6.34  \n",
      "6     8.49    8.16    8.37    8.39  0.58  471.97   3.57  \n",
      "7     4.18    4.06    4.12    4.08  1.51   -9.56   2.19  \n",
      "8    11.56   11.28   11.33   11.41  2.81   21.97   2.23  \n",
      "9     7.87    7.62    7.76    7.73  2.90   25.30   1.82  \n",
      "10   41.85   39.80   40.35   40.81  2.61  -98.82   7.88  \n",
      "11    7.29    7.17    7.27    7.28  0.18  304.54   1.34  \n",
      "12   17.86   17.21   17.42   17.46  0.78   16.32   1.30  \n",
      "13    3.84    3.79    3.84    3.84  0.10    6.41   0.44  \n",
      "14    8.38    8.25    8.35    8.37  0.41   21.67   1.06  \n",
      "15   16.83   16.50   16.75   16.87  0.18   15.83   0.66  \n",
      "16    8.06    7.86    8.05    8.05  0.21  -27.46   0.71  \n",
      "17   25.18   24.52   25.16   25.10  0.41   17.98   4.54  \n",
      "18   12.73   12.48   12.72   12.81  0.95   22.90   1.59  \n",
      "19   23.41   22.78   23.38   23.42  0.24   27.89   2.37  \n",
      "20    4.02    3.93    3.98    4.06  1.67  -54.13   3.76  \n",
      "21    6.02    5.81    5.83    6.09  3.40   -9.77   2.48  \n",
      "22    7.54    7.35    7.50    7.58  1.82  -10.76   9.01  \n",
      "23    2.15    2.07    2.14    2.17  2.12  -25.57  79.96  \n",
      "24  144.98  139.10  144.00  145.87  2.29   76.23  10.59  \n",
      "25  403.00  382.80  401.81  405.00  1.79   31.88   5.35  \n"
     ]
    }
   ],
   "source": [
    "stock_board_industry_cons_em_df = ak.stock_board_industry_cons_em(symbol=\"汽车整车\")\n",
    "print(stock_board_industry_cons_em_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['600686', '600375', '000550', '600166', '200550', '000868',\n",
       "       '601777', '600303', '000957', '600006', '600418', '000800',\n",
       "       '000951', '200625', '301039', '600104', '601238', '600066',\n",
       "       '000625', '601633', '000572', '600841', '600733', '000980',\n",
       "       '601127', '002594'], dtype=object)"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stock_board_industry_cons_em_df.代码.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded cached bar data.\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>date</th>\n",
       "      <th>symbol</th>\n",
       "      <th>open</th>\n",
       "      <th>high</th>\n",
       "      <th>low</th>\n",
       "      <th>close</th>\n",
       "      <th>volume</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>19674</th>\n",
       "      <td>2016-01-04</td>\n",
       "      <td>000957</td>\n",
       "      <td>44.77</td>\n",
       "      <td>44.77</td>\n",
       "      <td>40.37</td>\n",
       "      <td>40.37</td>\n",
       "      <td>49046</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19675</th>\n",
       "      <td>2016-01-05</td>\n",
       "      <td>000957</td>\n",
       "      <td>38.38</td>\n",
       "      <td>41.65</td>\n",
       "      <td>37.87</td>\n",
       "      <td>41.65</td>\n",
       "      <td>101643</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19676</th>\n",
       "      <td>2016-01-06</td>\n",
       "      <td>000957</td>\n",
       "      <td>41.72</td>\n",
       "      <td>44.17</td>\n",
       "      <td>41.47</td>\n",
       "      <td>43.83</td>\n",
       "      <td>86237</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19677</th>\n",
       "      <td>2016-01-07</td>\n",
       "      <td>000957</td>\n",
       "      <td>42.75</td>\n",
       "      <td>42.88</td>\n",
       "      <td>39.85</td>\n",
       "      <td>39.96</td>\n",
       "      <td>18357</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19678</th>\n",
       "      <td>2016-01-08</td>\n",
       "      <td>000957</td>\n",
       "      <td>41.56</td>\n",
       "      <td>42.53</td>\n",
       "      <td>38.23</td>\n",
       "      <td>41.18</td>\n",
       "      <td>58311</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39197</th>\n",
       "      <td>2024-12-25</td>\n",
       "      <td>200550</td>\n",
       "      <td>29.18</td>\n",
       "      <td>29.20</td>\n",
       "      <td>29.10</td>\n",
       "      <td>29.16</td>\n",
       "      <td>1910</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39198</th>\n",
       "      <td>2024-12-26</td>\n",
       "      <td>200550</td>\n",
       "      <td>29.13</td>\n",
       "      <td>29.26</td>\n",
       "      <td>29.12</td>\n",
       "      <td>29.17</td>\n",
       "      <td>1985</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39199</th>\n",
       "      <td>2024-12-27</td>\n",
       "      <td>200550</td>\n",
       "      <td>29.20</td>\n",
       "      <td>29.25</td>\n",
       "      <td>29.16</td>\n",
       "      <td>29.18</td>\n",
       "      <td>693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39200</th>\n",
       "      <td>2024-12-30</td>\n",
       "      <td>200550</td>\n",
       "      <td>29.14</td>\n",
       "      <td>29.18</td>\n",
       "      <td>29.11</td>\n",
       "      <td>29.11</td>\n",
       "      <td>1665</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39201</th>\n",
       "      <td>2024-12-31</td>\n",
       "      <td>200550</td>\n",
       "      <td>29.14</td>\n",
       "      <td>29.14</td>\n",
       "      <td>28.98</td>\n",
       "      <td>29.06</td>\n",
       "      <td>2136</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>54374 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            date  symbol   open   high    low  close  volume\n",
       "19674 2016-01-04  000957  44.77  44.77  40.37  40.37   49046\n",
       "19675 2016-01-05  000957  38.38  41.65  37.87  41.65  101643\n",
       "19676 2016-01-06  000957  41.72  44.17  41.47  43.83   86237\n",
       "19677 2016-01-07  000957  42.75  42.88  39.85  39.96   18357\n",
       "19678 2016-01-08  000957  41.56  42.53  38.23  41.18   58311\n",
       "...          ...     ...    ...    ...    ...    ...     ...\n",
       "39197 2024-12-25  200550  29.18  29.20  29.10  29.16    1910\n",
       "39198 2024-12-26  200550  29.13  29.26  29.12  29.17    1985\n",
       "39199 2024-12-27  200550  29.20  29.25  29.16  29.18     693\n",
       "39200 2024-12-30  200550  29.14  29.18  29.11  29.11    1665\n",
       "39201 2024-12-31  200550  29.14  29.14  28.98  29.06    2136\n",
       "\n",
       "[54374 rows x 7 columns]"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = akshare.query(\n",
    "    symbols=stock_board_industry_cons_em_df.代码.unique(),\n",
    "    start_date='1/1/2016',\n",
    "    end_date='1/1/2025',\n",
    "    adjust=\"hfq\",\n",
    "    timeframe=\"1d\",\n",
    ")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "z=df.groupby(\"symbol\", group_keys=False).close.apply(compute_bb)\n",
    "df=df.join(z)\n",
    "df[\"close-o\"]=df[\"close\"]-df[\"open\"]\n",
    "df[\"high-l\"]=df[\"high\"]-df[\"low\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"return_s1\"]=df.groupby(\"symbol\", group_keys=False).close.apply(func0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "xy=df[df.date<datetime.datetime(2024,1,1)].iloc[:,2:].dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "xy_x=xy.drop(\"return_s1\",axis=1)\n",
    "xy_y=xy[\"return_s1\"]\n",
    "x1,x2,y1,y2=train_test_split(xy_x,xy_y,test_size=0.8)#分割数据出训练集与测试集，0.7是两者行数的比例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000570 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000765 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000739 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000905 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000595 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000636 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000761 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000783 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000727 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.003132 seconds.\n",
      "You can set `force_row_wise=true` to remove the overhead.\n",
      "And if memory is not enough, you can set `force_col_wise=true`.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000752 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000881 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000753 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000626 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000803 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001132 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000814 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000673 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000741 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000732 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000866 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000723 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000647 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000567 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000753 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000729 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000723 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000933 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000810 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000737 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000710 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000586 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001218 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000746 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001818 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000733 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000763 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000808 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000911 seconds.\n",
      "You can set `force_row_wise=true` to remove the overhead.\n",
      "And if memory is not enough, you can set `force_col_wise=true`.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000736 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000712 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.002848 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000650 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000684 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000690 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000727 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000761 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000645 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.002447 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.004566 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000740 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001560 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6296, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000248\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000815 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score -0.000156\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000764 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 6297, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000421\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001078 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 5355\n",
      "[LightGBM] [Info] Number of data points in the train set: 9445, number of used features: 21\n",
      "[LightGBM] [Info] Start training from score 0.000005\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: #000;\n",
       "  --sklearn-color-text-muted: #666;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-1 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-1 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: flex;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "  align-items: start;\n",
       "  justify-content: space-between;\n",
       "  gap: 0.5em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label .caption {\n",
       "  font-size: 0.6rem;\n",
       "  font-weight: lighter;\n",
       "  color: var(--sklearn-color-text-muted);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-1 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-1 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 0.5em;\n",
       "  text-align: center;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-1 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=3, estimator=LGBMRegressor(objective=&#x27;regression&#x27;),\n",
       "             param_grid={&#x27;learning_rate&#x27;: [0.05, 0.1, 0.15],\n",
       "                         &#x27;n_estimators&#x27;: [20, 30, 50], &#x27;num_leaves&#x27;: [31, 50]})</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>GridSearchCV</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.model_selection.GridSearchCV.html\">?<span>Documentation for GridSearchCV</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>GridSearchCV(cv=3, estimator=LGBMRegressor(objective=&#x27;regression&#x27;),\n",
       "             param_grid={&#x27;learning_rate&#x27;: [0.05, 0.1, 0.15],\n",
       "                         &#x27;n_estimators&#x27;: [20, 30, 50], &#x27;num_leaves&#x27;: [31, 50]})</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>best_estimator_: LGBMRegressor</div></div></label><div class=\"sk-toggleable__content fitted\"><pre>LGBMRegressor(learning_rate=0.05, n_estimators=20, objective=&#x27;regression&#x27;)</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>LGBMRegressor</div></div></label><div class=\"sk-toggleable__content fitted\"><pre>LGBMRegressor(learning_rate=0.05, n_estimators=20, objective=&#x27;regression&#x27;)</pre></div> </div></div></div></div></div></div></div></div></div>"
      ],
      "text/plain": [
       "GridSearchCV(cv=3, estimator=LGBMRegressor(objective='regression'),\n",
       "             param_grid={'learning_rate': [0.05, 0.1, 0.15],\n",
       "                         'n_estimators': [20, 30, 50], 'num_leaves': [31, 50]})"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_grid = {\n",
    "    'num_leaves': [31, 50],\n",
    "    'learning_rate': [0.05, 0.1, 0.15],\n",
    "    'n_estimators': [20, 30, 50]\n",
    "}\n",
    "gbm = lgb.LGBMRegressor(objective='regression')\n",
    "grid_search = GridSearchCV(gbm, param_grid, cv=3)\n",
    "grid_search.fit(x1,y1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'learning_rate': 0.05, 'n_estimators': 20, 'num_leaves': 31}\n"
     ]
    }
   ],
   "source": [
    "print(grid_search.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.004853512439754848\n"
     ]
    }
   ],
   "source": [
    "best_model = grid_search.best_estimator_\n",
    "y_pred = best_model.predict(x2)\n",
    "print(r2_score(y2,y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.0115241069145251\n"
     ]
    }
   ],
   "source": [
    "xy_t=df[(df.date>datetime.datetime(2024,1,1))].iloc[:,2:].dropna()\n",
    "xy_x=xy_t.drop(\"return_s1\",axis=1)\n",
    "xy_y=xy_t[\"return_s1\"]\n",
    "y_pred = best_model.predict(xy_x)\n",
    "print(r2_score(xy_y,y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred=pd.Series(y_pred,index=xy_x.index,name=\"pred\")\n",
    "pyb_data_pe=df[df.date>datetime.datetime(2024,1,1)].iloc[:,0:7]\n",
    "pyb_data_pe = pd.concat([pyb_data_pe, y_pred], axis=1)\n",
    "pyb_data_pe.fillna(0,inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backtesting: 2024-01-01 00:00:00 to 2024-05-01 00:00:00\n",
      "\n",
      "Test split: 2024-01-02 00:00:00 to 2024-04-30 00:00:00\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0% (0 of 78) |                         | Elapsed Time: 0:00:00 ETA:  --:--:--\n",
      "  1% (1 of 78) |                         | Elapsed Time: 0:00:00 ETA:   0:00:07\n",
      " 14% (11 of 78) |###                     | Elapsed Time: 0:00:00 ETA:   0:00:01\n",
      " 26% (21 of 78) |######                  | Elapsed Time: 0:00:00 ETA:   0:00:01\n",
      " 39% (31 of 78) |#########               | Elapsed Time: 0:00:00 ETA:   0:00:00\n",
      " 52% (41 of 78) |############            | Elapsed Time: 0:00:00 ETA:   0:00:00\n",
      " 65% (51 of 78) |###############         | Elapsed Time: 0:00:00 ETA:   0:00:00\n",
      " 78% (61 of 78) |##################      | Elapsed Time: 0:00:00 ETA:   0:00:00\n",
      " 91% (71 of 78) |#####################   | Elapsed Time: 0:00:01 ETA:   0:00:00\n",
      "100% (78 of 78) |########################| Elapsed Time: 0:00:01 ETA:  00:00:00\n",
      "100% (78 of 78) |########################| Elapsed Time: 0:00:01 Time:  0:00:01\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished backtest: 0:00:01\n"
     ]
    }
   ],
   "source": [
    "pybroker.register_columns('pred')\n",
    "config = StrategyConfig(initial_cash=10000000)\n",
    "strategy = Strategy(pyb_data_pe, '2024-01-01', '2024-05-01',config)\n",
    "strategy.add_execution(hold_long, pyb_data_pe.symbol.unique())\n",
    "result = strategy.backtest()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>trade_count</td>\n",
       "      <td>510.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>initial_market_value</td>\n",
       "      <td>10000000.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>end_market_value</td>\n",
       "      <td>9949515.08</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>total_pnl</td>\n",
       "      <td>-73616.58</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>unrealized_pnl</td>\n",
       "      <td>23131.66</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   name        value\n",
       "0           trade_count       510.00\n",
       "1  initial_market_value  10000000.00\n",
       "2      end_market_value   9949515.08\n",
       "3             total_pnl    -73616.58\n",
       "4        unrealized_pnl     23131.66"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.metrics_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quant",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
